Abstract
Nuestra intención es diseñar y evaluar modelos para detectar la fase REM del sueño en un paciente. Disponemos de datos de 3 pacientes: los pacientes 9 y 10 para entrev¡nar los modelos, y el paciente 8 para realizar el test de los modelosInformación de los datos:
El PSG está compuesto por las señales de los siguientes 11 canales:
6 EEG (F3, C3, O1, F4, C4 y O2);
2 EOG, derecho e izquierdo (ROC y LOC);
2 tipos de EMG (un m. submentalis – EMG de la barbilla (X1) – y dos m. tibialis – EMG de las piernas);
Las referencias se colocaron en los lóbulos de las orejas izquierda y derecha (A1, A2).
La carga de los datos del paciente se realizará en dos fases:
table(info_8$Stage)
##
## N1 N2 N3 R W
## 119 201 143 161 376
table(info_9$Stage)
##
## N1 N2 N3 R w W
## 167 365 225 63 6 143
table(info_10$Stage)
##
## N1 N2 N3 R W
## 221 181 112 135 147
Es importante descartar las últimas 30 por alto ruido, así el número de épocas válidas:
nvalid_epochs_8 <-length(info_8$Stage)-30
nvalid_epochs_9 <-length(info_9$Stage)-30
nvalid_epochs_10 <-length(info_10$Stage)-30
Se generarán las variables/características a partir de cada época de los canales. Es decir, de cada 30 segundos de cada canal obtendremos los estadísticos que mejor estimen la fase.
Como ejemplo, se ha determinado dos características temporales (media y desviación estándar) y una frecuencial (frecuencia dominante)
## Class
## REM noREM
## 161 839
## Class
## REM noREM
## 63 906
## Class
## REM noREM
## 135 661
Este es el dataframe con el que vamos a trabajar, contiene información estadística de cada variable. Veamos la información del paciente número 10.
head(X_10)
## meanF3 sdF3 dfreqF3 meanC3 sdC3 dfreqC3
## 1 -0.02129953 3.307478 0.0004333333 -0.008281218 2.472070 0.0008000000
## 2 -0.01296735 1.869238 0.0010333333 -0.013510734 1.615755 0.0006000000
## 3 0.02740744 2.712249 0.0005666667 0.022949837 2.522197 0.0005666667
## 4 0.01835558 1.244372 0.0011666667 0.020559991 1.116382 0.0013666667
## 5 0.03335134 3.232505 0.0004333333 0.012417492 2.666751 0.0004666667
## 6 -0.03305958 2.692889 0.0005666667 -0.018666463 2.434598 0.0005666667
## meanO1 sdO1 dfreqO1 meanF4 sdF4 dfreqF4
## 1 0.04224522 3.501601 0.0010000000 0.005707692 3.186690 0.0004333333
## 2 -0.01275098 1.261451 0.0006000000 -0.018819259 1.821697 0.0010333333
## 3 0.02080345 3.820320 0.0005666667 -0.007128261 3.419440 0.0008000000
## 4 0.02325180 1.335003 0.0013666667 0.010521794 1.176076 0.0006333333
## 5 0.04738654 6.100085 0.0004666667 0.014494943 3.137274 0.0007000000
## 6 -0.02174191 3.257066 0.0011666667 0.010703275 2.446140 0.0007000000
## meanC4 sdC4 dfreqC4 meanO2 sdO2 dfreqO2
## 1 0.015586052 2.032316 0.0005666667 0.063933988 3.948716 0.0005666667
## 2 -0.015209408 1.937843 0.0012000000 -0.019611030 1.404053 0.0010000000
## 3 -0.001417834 3.374258 0.0005000000 0.011801070 5.218386 0.0004333333
## 4 0.009785975 1.210531 0.0008333333 0.008987494 1.028485 0.0008333333
## 5 0.020133201 2.677501 0.0007666667 0.034537732 5.370663 0.0004666667
## 6 0.015018397 1.883838 0.0007333333 0.040961359 2.707964 0.0007333333
## meanROC sdROC dfreqROC meanLOC sdLOC dfreqLOC
## 1 0.013710203 3.291193 0.0005333333 -0.049747498 3.476788 0.0005333333
## 2 -0.032472221 1.867236 0.0005000000 -0.008744617 2.829082 0.0010333333
## 3 -0.004739106 3.426522 0.0004000000 -0.006625409 3.149877 0.0008333333
## 4 0.012049015 1.639314 0.0005333333 -0.007131169 1.777918 0.0005000000
## 5 -0.022030839 3.836607 0.0005666667 0.047807370 3.381188 0.0004000000
## 6 0.106578943 5.168857 0.0007333333 -0.076459227 5.937141 0.0007333333
## meanX1 sdX1 dfreqX1 meanX2 sdX2 dfreqX2
## 1 0.001595752 3.931016 0.04563333 0.026736492 14.430362 0.02596667
## 2 -0.004571017 2.152961 0.02166667 -0.025991514 4.280309 0.01650000
## 3 -0.005670558 3.042263 0.04663333 0.001860911 18.431195 0.02193333
## 4 -0.001394683 0.756855 0.02443333 -0.053518532 4.119712 0.01526667
## 5 -0.002313669 2.727392 0.04806667 -0.003461203 15.035419 0.03526667
## 6 0.003494899 2.412373 0.04613333 -0.010797046 12.203041 0.03426667
## meanX3 sdX3 dfreqX3 Class
## 1 -0.0064104957 2.9799832 0.03380000 noREM
## 2 -0.0002292077 0.2014679 0.03126667 noREM
## 3 0.0016962288 0.3348723 0.01483333 noREM
## 4 -0.0003637455 0.6656458 0.02786667 noREM
## 5 0.0047002671 2.6019201 0.02630000 noREM
## 6 0.0068909113 3.4868105 0.01786667 noREM
Unimos los datos de entrenamientoseran el paciente 9 y 10. Testearemos los resultado de los modelos implementados con el paciente 8.
X_train <- rbind.data.frame(X_9,X_10)
QQ-plot de cada variable, el color representa el estado del paciente, si es REM o NoRem.
Con estos gráficos podemos comprobar de un modo visual y rápido si las variables siguen la misma distribución para las fases noREM como para la fase REM.
## geom_draw_grob: grob = list(x = 0.5, y = 0.5, name = "GRID.null.607", gp = NULL, vp = NULL), xmin = 0, xmax = 0.2, ymin = 0, ymax = 0.2, scale = 1, clip = inherit, halign = 0.5, valign = 0.5
## stat_identity: na.rm = FALSE
## position_identity
## geom_draw_grob: grob = list(x = 0.5, y = 0.5, name = "GRID.null.1202", gp = NULL, vp = NULL), xmin = 0, xmax = 0.2, ymin = 0, ymax = 0.2, scale = 1, clip = inherit, halign = 0.5, valign = 0.5
## stat_identity: na.rm = FALSE
## position_identity
## geom_draw_grob: grob = list(x = 0.5, y = 0.5, name = "GRID.null.1797", gp = NULL, vp = NULL), xmin = 0, xmax = 0.2, ymin = 0, ymax = 0.2, scale = 1, clip = inherit, halign = 0.5, valign = 0.5
## stat_identity: na.rm = FALSE
## position_identity
En algunas de nuestras variables podemos encontrar que ambas clases siguen distribuciones muy similares, pero en otras muchas si difieren, sobretodo en aquellas en las que el estadístico del canal que estamos utilizando es la desviación estándar.
Esto nos indica que posiblemente un comportamiento distinto de la desviación estándar puede significar una diferencia importante en la clasificación posterior.
Boxplot de cada variable, dependiendo de si la fase es Rem o NoRem.
plot1<-plot_boxplot(X_train %>% select(starts_with('mean'),'Class'), by = "Class")
plot1
## $page_1
plot2<-plot_boxplot(X_train %>% select(starts_with('sd'),'Class'), by = "Class")
plot2
## $page_1
plot3<-plot_boxplot(X_train %>% select(starts_with('dfreq'),'Class'), by = "Class")
plot3
## $page_1
Una vez analizados los diagramas de cajas, y teniendo como referencia la conclusión anterior, encontramos otro patrón en los datos. Entre las distribuciones que más parecen diferir (las de la desviación estándar) las más destacables son las pertenecientes a los canales del Encefalograma.
Las variables a priori más destacables las volvemos a visualizar ahora mediante Diagramas de Violín.
Diagrama de violín de las variables más destacables dependiendo de la fase del sueño en el que se encuenre: Rem o NoRem.
plot1 <- ggbetweenstats(X_train, Class, sdF3, type='np', main ='sdF3')
plot2 <- ggbetweenstats(X_train, Class, sdF4, type='np', main ='sdF4')
plot3 <- ggbetweenstats(X_train, Class, sdC3, type='np', main ='sdC3')
plot4 <- ggbetweenstats(X_train, Class, sdC4, type='np', main ='sdC4')
plot5 <- ggbetweenstats(X_train, Class, sdF3, type='np', main ='sdO1')
plot6 <- ggbetweenstats(X_train, Class, sdF4, type='np', main ='sdO2')
plot7 <- ggbetweenstats(X_train, Class, sdX1, type='np', main ='dfreqX1')
plot8 <- ggbetweenstats(X_train, Class, meanX2, type='np', main ='meanX2')
plot9 <- ggbetweenstats(X_train, Class, sdROC, type='np', main ='sdROC')
Correlación entre las variables.
Podemos apreciar una gran correlación entre las varaibles meanC4 y meanC4. En general, observamos que existe correlación entre aquellos canales correspondiente al EEG. En cambio las variables corresponde¡ientes al EOC y EMG no presentan mucha correlación. Algo a destacar es que la media de LOC y ROC no presentan nada de correlación.
En este gráfico sobre el estadístico de la dispersión, se presenta una mayor correlación. Las variables pertenecientes a EEG, sobretodo, estan bastante correlacionadas. Luego, podemos observar correlación entre la sd de X2 y X3 y de sdLOC y sdROC.
Observamos valores parecidos a la media.
Las normalizaciones que vamos a utilizaremos en el proyecto van a ser la estándar y la MinMax.
library(caret)
## Loading required package: lattice
set.seed(123)
standarizer <- preProcess(X_train, method=c('center','scale')) # Normalización estándar
X_train_std <- predict(standarizer, X_train)
X_test_std <- predict(standarizer, X_8)
min_max_scaler <- preProcess(X_train, method='range') # Normalización estándar
X_train_mm <- predict(min_max_scaler, X_train)
X_test_mm <- predict(min_max_scaler, X_8)
Realizamos un modelo Lineal Generalizado Mixto. Se ha realizado el modelo con los datos sin normalizar, con la normalización estándar y con la normalización MInMax.
set.seed(1234)
model_glm<- glm(Class ~ . ,
data = X_train_std,
family = "binomial")
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
summary(model_glm)
##
## Call:
## glm(formula = Class ~ ., family = "binomial", data = X_train_std)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -8.4904 0.0003 0.0217 0.1293 3.0326
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 10.000343 0.830211 12.046 < 2e-16 ***
## meanF3 0.275254 0.742298 0.371 0.710776
## sdF3 -0.138862 1.900661 -0.073 0.941759
## dfreqF3 0.067316 0.168271 0.400 0.689123
## meanC3 -0.363162 0.718670 -0.505 0.613331
## sdC3 -1.149038 2.157550 -0.533 0.594334
## dfreqC3 -0.198985 0.163755 -1.215 0.224312
## meanO1 -0.421569 0.474832 -0.888 0.374634
## sdO1 8.941531 2.042670 4.377 1.2e-05 ***
## dfreqO1 0.490798 0.179660 2.732 0.006299 **
## meanF4 -0.025685 0.752034 -0.034 0.972754
## sdF4 0.415757 1.862764 0.223 0.823385
## dfreqF4 -0.137239 0.186566 -0.736 0.461969
## meanC4 0.323181 0.691750 0.467 0.640362
## sdC4 1.195792 1.994491 0.600 0.548808
## dfreqC4 -0.202218 0.144511 -1.399 0.161715
## meanO2 -0.113346 0.452094 -0.251 0.802036
## sdO2 -1.257626 0.937024 -1.342 0.179548
## dfreqO2 0.276298 0.178860 1.545 0.122401
## meanROC 0.068219 0.308083 0.221 0.824757
## sdROC -0.314177 0.613134 -0.512 0.608363
## dfreqROC 0.016725 0.170179 0.098 0.921710
## meanLOC -0.089499 0.194903 -0.459 0.646092
## sdLOC -0.593784 0.465245 -1.276 0.201856
## dfreqLOC 0.002064 0.170683 0.012 0.990351
## meanX1 2.657340 0.923917 2.876 0.004025 **
## sdX1 14.515934 1.625353 8.931 < 2e-16 ***
## dfreqX1 0.512398 0.132862 3.857 0.000115 ***
## meanX2 -0.349235 0.216532 -1.613 0.106775
## sdX2 -1.039252 0.333706 -3.114 0.001844 **
## dfreqX2 0.163426 0.160454 1.019 0.308431
## meanX3 -0.679531 0.575520 -1.181 0.237712
## sdX3 1.116872 0.576441 1.938 0.052680 .
## dfreqX3 0.306267 0.178150 1.719 0.085587 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1224.67 on 1704 degrees of freedom
## Residual deviance: 434.69 on 1671 degrees of freedom
## AIC: 502.69
##
## Number of Fisher Scoring iterations: 11
Mediante el análisis de los p-valores y los coeficientes asociados a las variables nos damos cuenta de que el modelo puede funcionar mejor utilizando un menor número de variables.
Es por esto que mediante el comando step, nos quedamos con las variables que suponen una menor pérdida de información. Este comando lo consigue mediante un análisis del estadistico AIC(Criterio de Información de Akaike).
Los resultados de las variables mas importantes, obtenido mediante en el step, son los mismo indistintamente de si utilizamos alguna normalización o no.
Es por lo mencionado anteriormente, por lo que solo usaremos la normalización estándar.
set.seed(1234)
model_glm2 <- glm(formula = Class ~ meanO1 + sdO1 +
dfreqO1 + dfreqC4 + sdLOC + meanX1 +
sdX1 + dfreqX1 + sdX2 + sdX3 +
dfreqX3, family = "binomial",
data = X_train_std)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
summary(model_glm2)
##
## Call:
## glm(formula = Class ~ meanO1 + sdO1 + dfreqO1 + dfreqC4 + sdLOC +
## meanX1 + sdX1 + dfreqX1 + sdX2 + sdX3 + dfreqX3, family = "binomial",
## data = X_train_std)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -8.4904 0.0004 0.0234 0.1366 3.1751
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 9.8461 0.7679 12.822 < 2e-16 ***
## meanO1 -0.5059 0.3274 -1.545 0.12230
## sdO1 7.8122 0.9070 8.613 < 2e-16 ***
## dfreqO1 0.4652 0.1690 2.752 0.00592 **
## dfreqC4 -0.2471 0.1244 -1.987 0.04688 *
## sdLOC -0.8060 0.1585 -5.085 3.67e-07 ***
## meanX1 2.3217 0.8603 2.699 0.00696 **
## sdX1 14.5579 1.4960 9.731 < 2e-16 ***
## dfreqX1 0.5073 0.1243 4.080 4.50e-05 ***
## sdX2 -0.6543 0.2231 -2.932 0.00336 **
## sdX3 0.7511 0.4402 1.706 0.08793 .
## dfreqX3 0.3317 0.1646 2.015 0.04385 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1224.67 on 1704 degrees of freedom
## Residual deviance: 443.48 on 1693 degrees of freedom
## AIC: 467.48
##
## Number of Fisher Scoring iterations: 11
Vemos que obtenemos unos altos valores de \(\beta\) en sdX1; por lo que podemos entender que es una variable que es muy importante para el modelo.
Veamos las predicciones con el conjunto test:
set.seed(1234)
# Resultados Datos Standard
y_Class_hat_std <- predict(model_glm2,
newdata=X_test_std,
type="response")
th_opt<-pROC::coords(roc(X_test_std$Class,y_Class_hat_std),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls < cases
y_hat<-factor(y_Class_hat_std<th_opt,
levels=c(TRUE,FALSE),
labels = c("REM","noREM"))
caret::confusionMatrix(y_hat,X_test_std$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 133 77
## noREM 16 744
##
## Accuracy : 0.9041
## 95% CI : (0.8838, 0.9219)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 8.224e-08
##
## Kappa : 0.6842
##
## Mcnemar's Test P-Value : 4.918e-10
##
## Sensitivity : 0.8926
## Specificity : 0.9062
## Pos Pred Value : 0.6333
## Neg Pred Value : 0.9789
## Prevalence : 0.1536
## Detection Rate : 0.1371
## Detection Prevalence : 0.2165
## Balanced Accuracy : 0.8994
##
## 'Positive' Class : REM
##
NOTA ¿Que va a ser importante para nuestros modelos? Vamos a tener en cuenta para este problema las siguientes métricas: - Accuracy - Area Under Curver - Positive Predicted Values (% Rem detectados) - Sensibilidad y Sensitividad
Vemos como la accuracy es del 90%, lo cual no es malo, pero podemos observar como el Positive Predicted Values es bastante bajo 63’33%. Nuestro reto es conseguir aumentar la tasa de PPV, sin tener una bajada de la sensitividad.
set.seed(1234)
y_test_roc <- X_test_std %>%
mutate(Class = ifelse(Class == "REM",1,0))
Class_hat_roc<-ifelse(y_hat=="REM",1,0)
roc_test <- roc(response= y_test_roc$Class,
predictor=Class_hat_roc,
quiet=TRUE,
plot=F)
plot.roc(roc_test,print.auc=TRUE,col="blue",
xlab="1-ESpecificidad",ylab="Sensibilidad")
Vamos a realizar un Random Forest con los datos sin normalizar y los datos normalizados para poder observar cual obtiene un mejor resultado.
set.seed(12345)
model_rf1=randomForest(Class ~ . ,
data = X_train,
family = "binomial",
ntree=1000,
classwt=c(0.8,0.2),
mtry=10)
varImpPlot(model_rf1)
Vemos como el modelo sobrepone las variables correspondientes a la desviación típica como las más importantes. Sobretodo, observamos una gran importacia a la variable sdX1.
Veamos los resultados:
prob_Class_hat<-predict(model_rf1,
newdata=X_8,
type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_8$Class,prob_Class_hat),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(prob_Class_hat>th_opt,
levels=c(TRUE,FALSE),
labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_8$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 136 62
## noREM 13 759
##
## Accuracy : 0.9227
## 95% CI : (0.904, 0.9387)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 5.747e-13
##
## Kappa : 0.7379
##
## Mcnemar's Test P-Value : 2.981e-08
##
## Sensitivity : 0.9128
## Specificity : 0.9245
## Pos Pred Value : 0.6869
## Neg Pred Value : 0.9832
## Prevalence : 0.1536
## Detection Rate : 0.1402
## Detection Prevalence : 0.2041
## Balanced Accuracy : 0.9186
##
## 'Positive' Class : REM
##
Observamos una mejora con el accuracy respecto al modelo lineal. Seguimos sin conseguir aumentar la tasa de PPV.
set.seed(12345)
model_rf2=randomForest(Class ~ . ,
data = X_train_std,
family = "binomial",
ntree=1000,
classwt=c(0.8,0.2),
mtry=10)
# varimp(model_rf2)
El gráfico de variables más importantes para el modelo es muy similar al anterior.
prob_Class_hat<-predict(model_rf2,
newdata=X_test_std,
type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_test_std$Class,prob_Class_hat),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(prob_Class_hat>th_opt,
levels=c(TRUE,FALSE),
labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_test_std$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 136 65
## noREM 13 756
##
## Accuracy : 0.9196
## 95% CI : (0.9007, 0.9359)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 5.579e-12
##
## Kappa : 0.7294
##
## Mcnemar's Test P-Value : 7.713e-09
##
## Sensitivity : 0.9128
## Specificity : 0.9208
## Pos Pred Value : 0.6766
## Neg Pred Value : 0.9831
## Prevalence : 0.1536
## Detection Rate : 0.1402
## Detection Prevalence : 0.2072
## Balanced Accuracy : 0.9168
##
## 'Positive' Class : REM
##
Observamos que el modelo con los datos normalizados mediante la normalización estándar obtiene unos peores resultados. El Accuracy y los PPV son inferiores.
library(randomForest)
set.seed(12345)
model_rf3 = randomForest(Class ~ . ,
data = X_train_mm,
family = "binomial",
ntree=1000,
classwt=c(0.8,0.2),
mtry=10)
#varImpPlot(model_rf3)
prob_Class_hat<-predict(model_rf3,
newdata=X_test_mm,
type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_test_mm$Class,prob_Class_hat),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(prob_Class_hat>th_opt,
levels=c(TRUE,FALSE),
labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_test_mm$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 137 68
## noREM 12 753
##
## Accuracy : 0.9175
## 95% CI : (0.8984, 0.9341)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 2.370e-11
##
## Kappa : 0.7251
##
## Mcnemar's Test P-Value : 7.788e-10
##
## Sensitivity : 0.9195
## Specificity : 0.9172
## Pos Pred Value : 0.6683
## Neg Pred Value : 0.9843
## Prevalence : 0.1536
## Detection Rate : 0.1412
## Detection Prevalence : 0.2113
## Balanced Accuracy : 0.9183
##
## 'Positive' Class : REM
##
Observamos unos resultados muy parecidos a los anteriores. Es por ello que el Random Forest funciona de mejor forma con los datos sin normalizar. Ahora, vamos a encontrar aquellos parámetros que mejor ajusten el modelo; en este caso, vamos a utilizar los datos sin normalizar debido a sus resultados.
control <- trainControl(method="repeatedcv",
number=10,
repeats=3,
classProbs=TRUE,
search="grid")
set.seed(1234)
tunegrid <- expand.grid(.mtry=c(1:15))
rf_gridsearch <- train(Class~.,
data=X_train,
method="rf",
metric='ROC',
tuneGrid=tunegrid,
trControl=control)
## Warning in train.default(x, y, weights = w, ...): The metric "ROC" was not in
## the result set. Accuracy will be used instead.
print(rf_gridsearch)
## Random Forest
##
## 1705 samples
## 33 predictor
## 2 classes: 'REM', 'noREM'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times)
## Summary of sample sizes: 1534, 1535, 1535, 1534, 1534, 1535, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 1 0.9312003 0.5648696
## 2 0.9554292 0.7558853
## 3 0.9640292 0.8118266
## 4 0.9659854 0.8250991
## 5 0.9663775 0.8293471
## 6 0.9679416 0.8392821
## 7 0.9685252 0.8427477
## 8 0.9685252 0.8431040
## 9 0.9685241 0.8439897
## 10 0.9693084 0.8476846
## 11 0.9693050 0.8486959
## 12 0.9685252 0.8443647
## 13 0.9691100 0.8475604
## 14 0.9677409 0.8403418
## 15 0.9675460 0.8396246
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 10.
plot(rf_gridsearch)
rf_gridsearch$bestTune
## mtry
## 10 10
prob_Class_hat<-predict(model_rf1,
newdata=X_8,
type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_8$Class,prob_Class_hat),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(prob_Class_hat>th_opt,
levels=c(TRUE,FALSE),
labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_8$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 136 62
## noREM 13 759
##
## Accuracy : 0.9227
## 95% CI : (0.904, 0.9387)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 5.747e-13
##
## Kappa : 0.7379
##
## Mcnemar's Test P-Value : 2.981e-08
##
## Sensitivity : 0.9128
## Specificity : 0.9245
## Pos Pred Value : 0.6869
## Neg Pred Value : 0.9832
## Prevalence : 0.1536
## Detection Rate : 0.1402
## Detection Prevalence : 0.2041
## Balanced Accuracy : 0.9186
##
## 'Positive' Class : REM
##
Vamos a interpretar los resultados del RF con la ayuda de la librería IML - Interprete Machine Learning.
library(iml)
features <- X_train %>% dplyr::select(-Class)
response <- as.numeric(X_train$Class=="REM")
table(X_train$Class)
##
## REM noREM
## 198 1507
table(response)
## response
## 0 1
## 1507 198
pred_RF <- function(model, newdata) {
predict(model, newdata, type="prob")[,"REM"]
}
predictor_RF <- Predictor$new(
model = model_rf1,
data = features,
y = response,
predict.fun = pred_RF,
class = "classification"
)
imp_RF <- FeatureImp$new(predictor_RF, loss = "ce",n.repetitions = 5)
plot(imp_RF) + ggtitle("imp_RF")
Podemos observar que variables son más importantes finalmente para el modelo. Observamos cuatro que estan bastante diferenciadas al resto: sdO1, sdC4, sdO1, sdO2. Posteriormente, encontramos otro grupo de tres variables: sdF3, sdROC y sdF4 las cuales son bastante importantes pero a un menor nivel respecto a las mencionadas. Podriamos proseguir con sdLOC y sdX1.
Ahora vamos a seleccionar las 14 primeras variables más importantes para construir el modelo; veamos los resultados:
set.seed(1233)
var_sel<-imp_RF$results$feature[1:14]
data_sel1<-X_train %>% dplyr::select(all_of(var_sel),"Class")
model_RF_sel1 <- randomForest(Class ~ . , data = data_sel1,ntree=1000,classwt=c(0.8,0.2),mtry=10)
p_RF<-predict(model_RF_sel1,X_8, type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_8$Class,p_RF),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(p_RF>th_opt,levels=c(TRUE,FALSE),labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_8$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 139 68
## noREM 10 753
##
## Accuracy : 0.9196
## 95% CI : (0.9007, 0.9359)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 5.579e-12
##
## Kappa : 0.7332
##
## Mcnemar's Test P-Value : 1.090e-10
##
## Sensitivity : 0.9329
## Specificity : 0.9172
## Pos Pred Value : 0.6715
## Neg Pred Value : 0.9869
## Prevalence : 0.1536
## Detection Rate : 0.1433
## Detection Prevalence : 0.2134
## Balanced Accuracy : 0.9250
##
## 'Positive' Class : REM
##
Podemos observar como, prediciendo con las 14 variables más importantes para el modelo conseguimos aumentar la tasa de PPV sin que la sensitividad disminuya y, como consequencia, aumentamos el accuracy.
Ahora vamos a seleccionar las variables que miden la desviación típica:
set.seed(1233)
data_sel2<-X_train %>% dplyr::select(starts_with('sd'),"Class")
model_RF_sel2<-randomForest(Class ~ . , data = data_sel2,ntree=1000,classwt=c(0.8,0.2),mtry=10)
p_RF<-predict(model_RF_sel2,X_8, type="prob")[,"REM"]
th_opt<-pROC::coords(roc(X_8$Class,p_RF),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
Class_hat<-factor(p_RF>th_opt,levels=c(TRUE,FALSE),labels = c("REM","noREM"))
caret::confusionMatrix(Class_hat,X_8$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 135 59
## noREM 14 762
##
## Accuracy : 0.9247
## 95% CI : (0.9063, 0.9405)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 1.177e-13
##
## Kappa : 0.7424
##
## Mcnemar's Test P-Value : 2.607e-07
##
## Sensitivity : 0.9060
## Specificity : 0.9281
## Pos Pred Value : 0.6959
## Neg Pred Value : 0.9820
## Prevalence : 0.1536
## Detection Rate : 0.1392
## Detection Prevalence : 0.2000
## Balanced Accuracy : 0.9171
##
## 'Positive' Class : REM
##
Obtenemos un aumento del PPV pero una reducción de la sensitividad, y obteniendo el mismo accuracy. Debido al gran pequeño aumento en la tasa del PPV, y una peor sensitividad; seleccionamos el modelo co la información de las 14 primeras variables como el que mejor Random Forest.
Vamos a realizar una máquina de vector soporte para detectar la fase REM. Vamos a mejorar el modelo viendo que hiperparámetros mejoran el modelo.
set.seed(1234)
table(X_train$Class)
##
## REM noREM
## 198 1507
library(yaImpute)
##
## Attaching package: 'yaImpute'
## The following object is masked from 'package:e1071':
##
## impute
## The following object is masked from 'package:lava':
##
## vars
## The following object is masked from 'package:ggplot2':
##
## vars
## The following object is masked from 'package:dplyr':
##
## vars
tuning <- tune(svm, Class ~ ., data = X_train_std,
ranges = list(gamma = seq(0.025,0.15,0.025), cost = seq(8,10,1)),
class.weights = c('REM'=0.7,'noREM'=0.3),
scale = TRUE
)
tuning$best.parameters
## gamma cost
## 14 0.05 10
library(plotly)
ggplotly(ggplot(data = tuning$performances, aes(x = cost, y = error,col=gamma)) +
#geom_line() +
geom_point() +
labs(title = "Error de validación ~ hiperparámetro C y gamma") +
theme_bw() +
theme(plot.title = element_text(hjust = 0.5)))
prediccion <- predict(tuning$best.model,X_test_std)
caret::confusionMatrix(prediccion,X_test_std$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 94 16
## noREM 55 805
##
## Accuracy : 0.9268
## 95% CI : (0.9086, 0.9424)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 2.273e-14
##
## Kappa : 0.6847
##
## Mcnemar's Test P-Value : 6.490e-06
##
## Sensitivity : 0.63087
## Specificity : 0.98051
## Pos Pred Value : 0.85455
## Neg Pred Value : 0.93605
## Prevalence : 0.15361
## Detection Rate : 0.09691
## Detection Prevalence : 0.11340
## Balanced Accuracy : 0.80569
##
## 'Positive' Class : REM
##
Vemos como el modelo mejora basyante el Positive Predicted Values, osea, la tasa de REM acertada es más alta que en los otros modelos. No obstante, la sensitividad es bastante inferior a los modelos anteriores, lo que equilibra el accuracy para que sea parecida al RF diseñado anteriormente.
El principal problema, y sobretodo en el área clínica és la interpretabilidad. Necesitamos que el modelo sea interpretable por el programador y, por tanto, comprensible. Como podemos ver, aunque el SVM no es un modelo que haga una mala detección, es un modelo que tiene una difícil interpretacción del proceso de predicción de los datos. Por tanto, en este caso es más adecuado usar el modelo atnterior, el Random Forest, ya que en este sabemos que variables interpreta el árbol, y de que forma.
X_train_std_tan <- X_train_std %>%
mutate(Class = ifelse(Class == "REM",1,0))
X_test_std_tan <- X_test_std %>%
mutate(Class = ifelse(Class == "REM",1,0))
X_train_std_tan$Class <- as.factor(X_train_std_tan$Class)
X_test_std_tan$Class <- as.factor(X_test_std_tan$Class)
Discretizamos lo datos para poder ejecutar el modelo basado en reglas:
set.seed(1233)
library(arulesCBA)
## Loading required package: arules
##
## Attaching package: 'arules'
## The following object is masked from 'package:recipes':
##
## discretize
## The following object is masked from 'package:car':
##
## recode
## The following object is masked from 'package:mc2d':
##
## lhs
## The following object is masked from 'package:effectsize':
##
## rules
## The following object is masked from 'package:dplyr':
##
## recode
## The following objects are masked from 'package:base':
##
## abbreviate, write
##
## Attaching package: 'arulesCBA'
## The following object is masked from 'package:arules':
##
## rules
## The following object is masked from 'package:effectsize':
##
## rules
data_to_disc <- rbind.data.frame(X_train_std_tan,X_test_std_tan)
data_disc<- arulesCBA::discretizeDF.supervised(formula=Class~.,
data=data_to_disc,
method = "mdlp")
Dividimos los datos discretizados en el conjunto train y test.
# Split:
data_disc_train <- data_disc[1:nrow(X_train_std_tan),]
data_disc_test <- data_disc[(nrow(X_train_std_tan)+1):(nrow(data_disc)),]
Balanceamos las clases para conseguir una mejor interpretacción por el modelo.
library(ROSE)
## Loaded ROSE 0.0-4
set.seed(9560)
rose_train <- ROSE(Class ~ ., data = data_disc_train)$data
table(rose_train$Class)
##
## 0 1
## 826 879
Ejecutamos el modelo TAN:
library(caret)
#library(bnclassify)
fitControl <- trainControl(method = "repeatedcv",
number=2, repeats=5,
classProbs = TRUE,
summaryFunction = twoClassSummary,
verbose=F)
tune_grid <- expand.grid(smooth=10^seq(-5,5,0.5),#seq(-5,5,0.5),
score='loglik')#score=c('loglik', 'bic', 'aic'))
set.seed(666)
data_model <- rose_train %>%
mutate(Class = ifelse(Class == 1,'si','no'))
x<-data_model[,names(data_model)!="Class"]
y<-data_model$Class
set.seed(666)
caret_tan <- caret::train(x,y,
method = "tan",
trControl = fitControl,
tuneGrid = tune_grid,
metric = "ROC",
maximize=TRUE
)
caret_tan$bestTune
## score smooth
## 8 loglik 0.03162278
ggplot(data=caret_tan$results,aes(x=smooth,y=ROC,color=score))+
geom_point(size=2, shape=21) +
geom_line()+
geom_errorbar(aes(ymin=ROC-ROCSD/2, ymax=ROC+ROCSD/2), width=.05, alpha=0.5)+
scale_x_continuous(trans='log10')
set.seed(1233)
library(gRain)
## Loading required package: gRbase
##
## Attaching package: 'gRbase'
## The following object is masked from 'package:signal':
##
## triang
## The following objects are masked from 'package:igraph':
##
## is_dag, topo_sort
## The following object is masked from 'package:compiler':
##
## compile
## The following objects are masked from 'package:lava':
##
## ancestors, children, edgeList, latent, latent<-, ordinal,
## ordinal<-, parents
## The following object is masked from 'package:scales':
##
## ordinal
## The following objects are masked from 'package:generics':
##
## compile, fit
## The following object is masked from 'package:R.oo':
##
## compile
x_test <- data_disc_test[,names(data_disc_test)!="Class"]
y_test <- data_disc_test$Class
p_test<-predict(caret_tan, newdata=x_test, type="prob")
roc_test<-roc(response=y_test,predictor=p_test[,"si"],quiet=TRUE,plot=F)
roc_test
##
## Call:
## roc.default(response = y_test, predictor = p_test[, "si"], quiet = TRUE, plot = F)
##
## Data: p_test[, "si"] in 821 controls (y_test 0) < 149 cases (y_test 1).
## Area under the curve: 0.9322
plot.roc(roc_test,print.auc=T,print.thres = "best",
col="blue",xlab="1-ESpecificidad",ylab="Sensibilidad")
set.seed(1233)
th_opt<-pROC::coords(roc(X_8$Class,p_test[,'si']),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
predicted_values<-ifelse(p_test[,'si']<th_opt,'0','1')
predicted_values <- as.numeric(predicted_values)
predicted_values <- as.factor(predicted_values)
caret::confusionMatrix(predicted_values,y_test)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 710 20
## 1 111 129
##
## Accuracy : 0.8649
## 95% CI : (0.8418, 0.8858)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 0.05773
##
## Kappa : 0.5845
##
## Mcnemar's Test P-Value : 3.74e-15
##
## Sensitivity : 0.8648
## Specificity : 0.8658
## Pos Pred Value : 0.9726
## Neg Pred Value : 0.5375
## Prevalence : 0.8464
## Detection Rate : 0.7320
## Detection Prevalence : 0.7526
## Balanced Accuracy : 0.8653
##
## 'Positive' Class : 0
##
classifier_tan<-caret_tan$finalModel
plot(classifier_tan)
Veamos otro modelo basado en reglas:
set.seed(1234)
library(arulesCBA)
classifier_CBA <- CBA(Class ~ .,data_disc_train, supp = 0.05, conf=0.95,)
classifier_CBA
## CBA Classifier Object
## Formula: Class ~ .
## Number of rules: 73
## Default Class: NA
## Classification method: first
## Description: CBA algorithm (Liu et al., 1998)
set.seed(1234)
prediction_CBA <- predict(classifier_CBA,data_disc_test)
caret::confusionMatrix(prediction_CBA,data_disc_test$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 812 50
## 1 9 99
##
## Accuracy : 0.9392
## 95% CI : (0.9222, 0.9534)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7364
##
## Mcnemar's Test P-Value : 1.913e-07
##
## Sensitivity : 0.9890
## Specificity : 0.6644
## Pos Pred Value : 0.9420
## Neg Pred Value : 0.9167
## Prevalence : 0.8464
## Detection Rate : 0.8371
## Detection Prevalence : 0.8887
## Balanced Accuracy : 0.8267
##
## 'Positive' Class : 0
##
vemos que el accuraccy es bueno, pero la especificidad en este caso es bastante peor, podemos observar que el balanced accuracy es peor.
Gracias al paquete iml procedemos a realizar una interpretación del modelo con unos mejores resultados con finalidad de mejorar el valor clínico de nuestro estudio.
Para ello no solo realizaremos la interpretabilidad facilitada por el Random Forest sino que también añadiremos la interpretabilidad que nos ofrecen los modelos de reglas.
Vamos ahora a interpretar el Random Forest con las 14 variables más importantes y, de esta forma, conseguir entender mejor los resultados.
set.seed(1233)
features <- data_sel1 %>%dplyr::select(-Class)
response <- as.numeric(data_sel1$Class=="REM")
table(data_sel1$Class)
##
## REM noREM
## 198 1507
pred_RF <- function(model, newdata) {
predict(model, newdata, type="prob")[,"REM"]
}
predictor_RF <- Predictor$new(
model = model_RF_sel1,
data = features,
y = response,
predict.fun = pred_RF,
class = "classification"
)
set.seed(1233)
imp_RF <- FeatureImp$new(predictor_RF, loss = "logLoss",n.repetitions = 100)
plot(imp_RF) + ggtitle("imp_RF")
imp_RF <- FeatureImp$new(predictor_RF, loss = "ce",n.repetitions = 100)
plot(imp_RF) + ggtitle("imp_RF")
Vemos que dentro del gráfico de las variables más importante para el RF se encuentran las desviaciones de los canales, y la frequencia deominante de X1 y C4.
set.seed(1233)
partial_RF<- FeatureEffect$new(predictor_RF, "sdO1", method="pdp+ice")
p1<-plot(partial_RF) + ggtitle("RF pdp + ice Desviación del canal O1")
partial_RF<- FeatureEffect$new(predictor_RF, "sdX1", method="pdp+ice")
p2<-plot(partial_RF) + ggtitle("RF pdp + ice Desviación del canal X1")
partial_RF<- FeatureEffect$new(predictor_RF, "sdC4", method="pdp+ice")
p3<-plot(partial_RF) + ggtitle("RF pdp + ice Desviación del canal C4")
p1/p2/p3
Podemos observar que en O1 y en C4 obtenemos un mayor efecto en el modelo respecto a la desviación de X1.
set.seed(1233)
ale_RF<- FeatureEffect$new(predictor_RF, c("sdX1"), method="ale")
p1<- plot(ale_RF) + ggtitle("RF ale")
ale_RF<- FeatureEffect$new(predictor_RF, c("sdO1"), method="ale")
p2<-plot(ale_RF) + ggtitle("RF ale")
ale_RF<- FeatureEffect$new(predictor_RF, c("sdC4"), method="ale")
p3<-plot(ale_RF) + ggtitle("RF ale")
(plot_spacer()/p2) | (p1/p3)
Vemos que sdC4 es bastante importante para el modelo debido a su no linealidad; lo que afecta a que glm no la considere una variable de las más importantes, y si lo haga con X1 la cual podemos ver que es lineal.
set.seed(1233)
interact_RF <- Interaction$new(predictor_RF) %>% plot() + ggtitle("RF")
plot(interact_RF)
Obtenemos las variables con una mayor interacción con el modelo. Vemos como sdC4,con sdX1 y sdO1 son las variables que más interactuan con el modelo, sobretodo estas dos últimas. Podemos apreciar que X1 no es una de las variables que el modelo más importáncea le dé, pero si que tiene una gran interacción con el modelo. Veamos más:
set.seed(1233)
interact_RF <- Interaction$new(predictor_RF,feature = "sdX1") %>% plot() + ggtitle("RF,sdX1")
plot(interact_RF)
interact_RF <- Interaction$new(predictor_RF,feature = "sdO1") %>% plot() + ggtitle("RF,sdO1")
plot(interact_RF)
Como podemos ver las variables con un mayor número de interacciones interactúan más con unas variables concretas. Por ejemplo, podemos observar que la desviación típica de O1 interactua con mayor medida con los EMG (X1, X2 y X3). Para poder analizar y observar distintos patrones, realizamos un estudio de la Dependencia parcial de estas interacciones más grandes
set.seed(1233)
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX2",'sdO1'), method="pdp")
p1<- plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdC4",'sdX1'), method="pdp")
p2<-plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX1",'sdO1'), method="pdp")
p3<- plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX3",'sdO1'), method="pdp")
p4<- plot(partial_RF) + ggtitle("RF pdp")
library(patchwork)
(p1/p2) | (p3/p4)
En estos gráficos podemos observar como la predicción según el valor de la variable. Podemos ver que cuanto más grande sea sdX2, y más pequeña sdO1, la predicción es mayor.
ggplot(X_train)+ geom_boxplot(aes(x=Class,y=sdC4))
Vemos que las importancia de la variable sdX1 es mucho mayor a las demás. Seguidamente encontramos sdO1 y sdC4 como las mas importantes.
set.seed(1233)
library(yaImpute)
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX2",'sdO1'), method="ale")
p1<- plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdC4",'sdX1'), method="ale")
p2<-plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX1",'sdO1'), method="ale")
p3<- plot(partial_RF) + ggtitle("RF pdp")
partial_RF<- FeatureEffect$new(predictor_RF, c("sdX3",'sdO1'), method="ale")
p4<- plot(partial_RF) + ggtitle("RF pdp")
(plot_spacer()/p2) | (p1/p3)
set.seed(1233)
tree_RF <- TreeSurrogate$new(predictor_RF, maxdepth = 3)
tree_RF$r.squared
## [1] 0.5396617
plot(tree_RF)
p_tree_RF<-predict(tree_RF,X_8, type="prob")[,".y.hat"]
## Warning in self$predictor$data$match_cols(data.frame(newdata)): Dropping
## additional columns: meanF3, dfreqF3, meanC3, dfreqC3, meanO1, dfreqO1, meanF4,
## dfreqF4, meanC4, meanO2, dfreqO2, dfreqROC, meanLOC, dfreqLOC, meanX1, meanX2,
## dfreqX2, meanX3, dfreqX3, Class
th_opt<-pROC::coords(roc(X_8$Class,p_tree_RF),"best",
best.method="closest.topleft")$threshold
## Setting levels: control = REM, case = noREM
## Setting direction: controls > cases
prediction_tree_RF<-factor(p_tree_RF>th_opt,levels=c(TRUE,FALSE),labels=c("REM","noREM"))
caret::confusionMatrix(prediction_tree_RF,X_8$Class)
## Confusion Matrix and Statistics
##
## Reference
## Prediction REM noREM
## REM 134 95
## noREM 15 726
##
## Accuracy : 0.8866
## 95% CI : (0.8649, 0.9059)
## No Information Rate : 0.8464
## P-Value [Acc > NIR] : 0.0001892
##
## Kappa : 0.6424
##
## Mcnemar's Test P-Value : 4.983e-14
##
## Sensitivity : 0.8993
## Specificity : 0.8843
## Pos Pred Value : 0.5852
## Neg Pred Value : 0.9798
## Prevalence : 0.1536
## Detection Rate : 0.1381
## Detection Prevalence : 0.2361
## Balanced Accuracy : 0.8918
##
## 'Positive' Class : REM
##
Vemos este conjunto de reglas a partir del boxplot white model generado, que se centra, sobretodo en la variable sdC4. Observamos también las variables dfreqX1, sdX1 y sdROC . El resultado es alto cuando 1.12 < sdc4 < 1.267 y sdROC > 1.64. También podemos ver que cuando sdC4>1.12, normalmente la y es pequeña.
Ahora procederemos a realizar las aproximaciones locales mediante modelos lineales con LIME, esta aproximación otorga un efecto a cada variable mostrando como se comporta, esto lo logra asumiendo linealidad.
Por lo tanto procederemos a realizar un estudio de los resultados de los efectos según lime para los valores con las probabilidades más altas y más bajas de nuestro modelo de Random Forest.
set.seed(1233)
high <- predict(model_rf1, X_train, type="prob")[,"REM"] %>% as.vector() %>% which.max()
low <- predict(model_rf1, X_train, type="prob")[,"REM"] %>% as.vector() %>% which.min()
high_prob_ob <- features[high, ]
low_prob_ob <- features[low, ]
high_prob_ob
## sdC3 sdC4 sdO1 sdO2 sdROC sdF3 sdF4
## 1163 0.8959753 0.8378067 0.7466508 0.7011232 1.527213 0.9690622 0.9677567
## sdLOC sdX1 dfreqC4 dfreqX1 sdX3 meanROC sdX2
## 1163 2.056829 0.1085975 0.0006333333 0.0196 0.05964786 0.005175971 3.52611
low_prob_ob
## sdC3 sdC4 sdO1 sdO2 sdROC sdF3 sdF4 sdLOC
## 9 0.7240238 0.6963894 0.6926737 0.6832666 0.9706886 0.7976351 0.7278622 1.02009
## sdX1 dfreqC4 dfreqX1 sdX3 meanROC sdX2
## 9 0.5236844 0.001 0.0308 0.3933177 0.008670708 2.691721
set.seed(1233)
lime_RF <- LocalModel$new(predictor_RF, k = 8 ,x.interest = high_prob_ob) #glmnet
## Loading required package: glmnet
## Loaded glmnet 4.1-3
## Warning in private$aggregate(): Had to choose a smaller k
lime_RF[["results"]] %>% arrange(-effect)
## beta x.recoded effect x.original feature
## 1 0.04600536 2.0568287 0.094625152 2.05682868708848 sdLOC
## 2 0.02083093 3.5261102 0.073452154 3.52611020610991 sdX2
## 3 -0.04000000 0.1085975 -0.004343899 0.108597461010821 sdX1
## 4 -0.01780487 0.7466508 -0.013294018 0.746650823555587 sdO1
## 5 -0.04283079 0.8378067 -0.035883926 0.837806721697553 sdC4
## 6 -0.06922476 0.8959753 -0.062023679 0.895975337591621 sdC3
## 7 -7.24526795 0.0196000 -0.142007252 0.0196 dfreqX1
## feature.value
## 1 sdLOC=2.05682868708848
## 2 sdX2=3.52611020610991
## 3 sdX1=0.108597461010821
## 4 sdO1=0.746650823555587
## 5 sdC4=0.837806721697553
## 6 sdC3=0.895975337591621
## 7 dfreqX1=0.0196
plot(lime_RF) + ggtitle("RF")
set.seed(1233)
lime_RF <- LocalModel$new(predictor_RF, k = 14, x.interest = high_prob_ob) #glmnet
lime_RF[["results"]] %>% arrange(-effect)
## beta x.recoded effect x.original feature
## 1 0.10685633 1.5272133934 0.163192415 1.52721339341384 sdROC
## 2 0.04300201 3.5261102061 0.151629835 3.52611020610991 sdX2
## 3 0.05021143 2.0568286871 0.103276311 2.05682868708848 sdLOC
## 4 0.02719672 0.9677566913 0.026319805 0.967756691287492 sdF4
## 5 0.02633333 0.9690622084 0.025518638 0.969062208401303 sdF3
## 6 12.12094544 0.0006333333 0.007676599 0.000633333333333333 dfreqC4
## 7 -0.21411598 0.0051759711 -0.001108258 0.00517597111327516 meanROC
## 8 -0.02131524 0.0596478604 -0.001271408 0.0596478604429808 sdX3
## 9 -0.06960103 0.1085974610 -0.007558495 0.108597461010821 sdX1
## 10 -0.02335057 0.7466508236 -0.017434723 0.746650823555587 sdO1
## 11 -0.02775621 0.7011231661 -0.019460525 0.701123166129685 sdO2
## 12 -0.09863056 0.8378067217 -0.082633342 0.837806721697553 sdC4
## 13 -0.15080156 0.8959753376 -0.135114482 0.895975337591621 sdC3
## 14 -9.13210156 0.0196000000 -0.178989191 0.0196 dfreqX1
## feature.value
## 1 sdROC=1.52721339341384
## 2 sdX2=3.52611020610991
## 3 sdLOC=2.05682868708848
## 4 sdF4=0.967756691287492
## 5 sdF3=0.969062208401303
## 6 dfreqC4=0.000633333333333333
## 7 meanROC=0.00517597111327516
## 8 sdX3=0.0596478604429808
## 9 sdX1=0.108597461010821
## 10 sdO1=0.746650823555587
## 11 sdO2=0.701123166129685
## 12 sdC4=0.837806721697553
## 13 sdC3=0.895975337591621
## 14 dfreqX1=0.0196
sum(lime_RF[["results"]]$effect)
## [1] 0.03404318
plot(lime_RF) + ggtitle("RF")
Como ya sabemos, el Random Forest no es un modelo lineal, por lo que estudiaremos los efectos si no asumimos linealidad en nuestros datos mediante los cálculos realizados con los Shap Values y así poder analizar que valores son los más importantes para conseguir una probabilidad tan alta de pertenecer a la fase ‘REM’.
set.seed(1233)
shapley_RF <- Shapley$new(predictor_RF, x.interest = high_prob_ob) %>% plot() + ggtitle("RF")
plot(shapley_RF) + ggtitle("RF")
shapley_RF$data
## feature phi phi.var feature.value
## 1 sdC3 0.04570 0.0086334444 sdC3=0.895975337591621
## 2 sdC4 0.12376 0.0833180630 sdC4=0.837806721697553
## 3 sdO1 0.15338 0.0900081774 sdO1=0.746650823555587
## 4 sdO2 0.02400 0.0026673131 sdO2=0.701123166129685
## 5 sdROC 0.01986 0.0018305257 sdROC=1.52721339341384
## 6 sdF3 0.00065 0.0004245732 sdF3=0.969062208401303
## 7 sdF4 0.00434 0.0008594186 sdF4=0.967756691287492
## 8 sdLOC 0.02895 0.0035216035 sdLOC=2.05682868708848
## 9 sdX1 0.34221 0.1425438039 sdX1=0.108597461010821
## 10 dfreqC4 -0.00065 0.0000662096 dfreqC4=0.000633333333333333
## 11 dfreqX1 0.02493 0.0078999243 dfreqX1=0.0196
## 12 sdX3 0.05524 0.0279628913 sdX3=0.0596478604429808
## 13 meanROC 0.00238 0.0001882784 meanROC=0.00517597111327516
## 14 sdX2 0.01459 0.0060695979 sdX2=3.52611020610991
library(arulesViz)
plot(head(rules(classifier_CBA),10, by = "confidence"), method="graph")
plot(head(rules(classifier_CBA),10, by = "confidence"), method="graph", engine="graphviz")
En primer lugar, destacar la diferencia existente entre las distribuciones de las desviaciones estándar, desde el Análisis Exploratorio encontramos una tendencia distinta para la fase REM que para las demás.
Mediante el modelo regresión logística seleccionamos ciertas variables que luego no son importantes para los modelos no lineales y viceversa, esto ocurre porque algunas variables (la más destacada sdC4) no se comportan de modo lineal .
Posteriormente, cuando seleccionamos las variables más importantes para el modelo de Random Forest, encontramos que la mayoría son variables que, como ya anticipábamos, representan la desviación estándar durante 30 segundos del sueño.
Estudiando la interpretabilidad de los modelos, encontramos que los modelos las clasifican como REM cuando éstas presentan valores más pequeños con respecto a la distribución total (en RF), mientras que en los modelos basados en reglas clasifican mejor la fase noREM, con la premisa de “una desviación estándar alta implica que la fase del sueño será no REM”. Llegando así a unas conclusiones similares al Random Forest.
Analizando las interacciones de las variables, encontramos también que las variables que explican en mayor proporción la variabilidad de la desviación estándar del canal O1 del encefalograma son las correspondientes a los canales de los distintos electromiogramas, por lo tanto, encontramos una interacción interesante entre el nivel muscular y el nivel cerebral para clasificar el sueño entre REM y no REM.